library(keras)
library(tensorflow)
library(abind)
library(reticulate)
library(dplyr)
library(parallel)
library(doParallel)

在理解Instance Normalization (IN)過程中,無意又看到了Adaptive Instance Normalization (AdaIN),印象中以前在 Hung-yi Lee 老師的教學影片中有聽過,作用是調整輸出的global資訊。在圖像的風格應用上,若調整圖像mean 和 variance而不改變其distribution,那就能實現風格的移轉。另外,有些風格移轉的方法不能套用到其他圖像,必須重新訓練,這次實作的模型沒有這個困擾,只要訓練完畢,就能進行即時的任意風格移轉…

下圖是這次的網路架構,Encoder前後都有,事實上是指相同的一個,其工作是負責圖像特徵提取,這裡使用的是pre-trained VGG (當然你也可以使用如resnet50、inception等其他model)。Encoder不必訓練,這是比較要注意的地方,其他細節後續介紹。原文可參考 Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization


使用的資料集有兩個,圖像內容使用COCO資料集,風格使用WikiArt,訓練資料各留75000張,測試資料各留4433,其他參數設定如下:

TSB_PATH        = 'data/AdaINStyleTransfer/logs_r'
SAVE_PATH       = 'data/AdaINStyleTransfer/save'
PLOT_PATH       = 'data/AdaINStyleTransfer/plot'

CONTENT_PATH    = 'data/COCO/train2017'
STYLE_PATH      = 'data/WikiArt/train'

IMAGE_H         = 256L
IMAGE_W         = 256L
BATCH_SIZE      = 10L
STYLE_WEIGHT    = 2.5
EPSILON         = 1e-5
EPOCHS          = 10

#Getting the Data
content_filenames <- list.files(CONTENT_PATH, pattern = glob2rx("*.jpg"), full.names = T, recursive = T)
style_filenames <- list.files(STYLE_PATH, pattern = glob2rx("*.jpg"), full.names = T, recursive = T)

max_length <- pmin(length(content_filenames) , length(style_filenames))

content_filenames <- content_filenames[1:max_length]
style_filenames <- style_filenames[1:max_length]

set.seed(777)
index <- sample(max_length, 4433)

content_filenames_test <- content_filenames[index]
content_filenames <- content_filenames[-index]

style_filenames_test <- style_filenames[index]
style_filenames <- style_filenames[-index]

因為用到pre-trained VGG,在輸入圖像一般會使用imagenet_preprocess_input方法,要注意的是它會將RGB順序轉成BGR(因為其預設mode是使用caffe)及對每個通道進行zero-center。相同的,在輸出圖像時我們利用deprocess_image方法,將之前各通道扣除的平均數加回並恢復RGB順序。另外,建立data_generator來產生批次所需的資料…

#deprocess
deprocess_image <- function(x) {
  x <- x[1,,,]
  x[,,1] <- x[,,1] + 103.939
  x[,,2] <- x[,,2] + 116.779
  x[,,3] <- x[,,3] + 123.68
  #BGR -> RGB
  x <- x[,,c(3,2,1)]
  x[x > 255] <- 255
  x[x < 0] <- 0
  x[] <- as.integer(x)/255
  x
}

#generator
data_generator <- function(content_names, style_names, batch_size) {
  
  content_fullnames <- content_names
  content_fullnames_all <- content_names
  
  style_fullnames <- style_names
  style_fullnames_all <- style_names
  
  function() {
    # start new epoch, reset
    if (length(content_fullnames) < batch_size) {
      content_fullnames <<- content_fullnames_all
      style_fullnames <<- style_fullnames_all
    }
    idx_c <- sample(1:length(content_fullnames), batch_size)
    idx_s <- sample(1:length(style_fullnames), batch_size)
    
    batch_content_names <- content_fullnames[idx_c]
    content_fullnames <<- content_fullnames[-idx_c]
    
    batch_style_names <- style_fullnames[idx_s]
    style_fullnames <<- style_fullnames[-idx_s]
    
    content_style_batch <- foreach(i = 1:batch_size) %dopar% {
      # read img, preprocess
      img_content <- image_load(batch_content_names[i], target_size = c(IMAGE_H, IMAGE_W)) %>%
        image_to_array() %>%
        #default model=caffe, convert the images from RGB to BGR
        imagenet_preprocess_input()
      
      img_style <- image_load(batch_style_names[i], target_size = c(IMAGE_H, IMAGE_W)) %>%
        image_to_array() %>%
        #default model=caffe, convert the images from RGB to BGR
        imagenet_preprocess_input()
      
      content_style <- list(x = img_content, y = img_style)
    }
    
    content_style_batch <- purrr::transpose(content_style_batch)
    content_batch <- do.call(abind, c(content_style_batch$x, list(along = 0)))
    attr(content_batch, 'dimnames') <- NULL
    style_batch <- do.call(abind, c(content_style_batch$y, list(along = 0)))
    attr(style_batch, 'dimnames') <- NULL
    
    result <- list(list(content_batch, style_batch), style_batch)
    
    return(result)
  }
}


train_iterator <- py_iterator(data_generator(content_filenames, style_filenames, batch_size = BATCH_SIZE))
test_iterator <- py_iterator(data_generator(content_filenames_test, style_filenames_test, batch_size = BATCH_SIZE))

VGG19當base model,載入imagenet權重。在風格參照有4個layers,內容參照則只有1個layer…

#base model vgg19
input_img <- layer_input(shape = list(NULL, NULL, 3))
base_model <- application_vgg19(weights = 'imagenet',include_top = FALSE, input_tensor = input_img)

#layers
content_layer <- "block4_conv1"
style_layers <- c("block1_conv1", "block2_conv1", "block3_conv1", content_layer)

cs_layers_output <- lapply(style_layers, function(name) base_model$get_layer(name)$output)

建立ENCODE_CS,也就是架構圖中綠色的Encoder,其作用是提取特徵。也就是不管輸入的圖像是內容或風格,都可得到其對應的4個layers的outputs,而第4個也就是先前定義的content layer

#ENCODE_CONTENT_STYLE
ENCODE_CS <- keras_model(inputs = base_model$input, outputs = cs_layers_output)

content_input <- layer_input(shape = list(NULL, NULL, 3))
style_input <- layer_input(shape = list(NULL, NULL, 3))

content_encoded <- ENCODE_CS(content_input)
style_encoded <- ENCODE_CS(style_input)

enc_c <- content_encoded[[4]]
enc_s <- style_encoded[[4]]

建立AdaIN layer,作用是將content特徵進行標準化並打包成一個layer。可以看到其主要是利用其mean和variance進行調整,其distribution沒什麼改變(ex:常態的還是常態)。上述的enc_c、enc_s分別是架構中左邊Encoder VGG輸出的藍色、紅色箭頭,AdaIN的輸出即以下的target_features

#AdaIN
AdaIN <- function(args){
  
  c(content_features, style_features) %<-% args
  
  c(content_mean, content_var) %<-% tf$nn$moments(content_features, axes = c(1L, 2L), keep_dims = T)
  c(style_mean, style_var) %<-% tf$nn$moments(style_features, axes = c(1L, 2L), keep_dims = T)
  
  content_std <- tf$sqrt(tf$add(content_var, EPSILON))
  style_std <- tf$sqrt(tf$add(style_var, EPSILON))
  
  normalized_content_features <- style_std * ((content_features - content_mean)  / content_std) + style_mean
  
  normalized_content_features
}

target_features <- list(enc_c, enc_s) %>%
  layer_lambda(AdaIN)

decoder的部份主要是進行upsampling,將圖像還原成原本的大小。這裡使用REFLECT的padding方法,避免預測圖像產生白邊,另外,上採樣不用transposed convolution,為避免生成網格。最後一層layer沒有指定activation,使用預設的線性函數

#decoder
decoder <- function(inputs) {
  out <- inputs %>%
    #zero_padding
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    # block4_conv1
    layer_conv_2d(filters = 256, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_upsampling_2d(interpolation = 'bilinear') %>%
    #block3
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 256, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 256, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 256, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 128, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_upsampling_2d(interpolation = 'bilinear') %>%
    #block2
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 128, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 64, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_upsampling_2d(interpolation = 'bilinear') %>%
    #block1
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 64, kernel_size = c(3 , 3), padding = "valid", activation = "relu") %>%
    layer_lambda(f = function(x) {tf$pad(x, tf$constant(list(c(0L, 0L), c(1L, 1L), c(1L, 1L), c(0L, 0L))) , mode = 'REFLECT')}) %>%
    layer_conv_2d(filters = 3, kernel_size = c(3 , 3), padding = "valid")
  out
}


dim <- k_int_shape(target_features)[-c(1, 2, 3)] %>% unlist()
decode_input <- layer_input(list(NULL, NULL, dim))
#?, ?, ?, 512
d_out <- decoder(decode_input)

在輸出圖像前,必須要注意zero-center和通道BGR的問題。前面有提到在輸入時各通道有扣除和置換的問題,這邊要先補回來與換回RGB,然後修剪成0至255的值。之後再置換成BGR及扣除平均數,為了下一個輸入(架構右邊的Encoder)。到這裡,完成了DECODER設定,輸入target_features,生成合成圖像generated_img

#generated_process
generated_process <- function(args){
  
  args <- k_concatenate(list(args[,,,1,drop=F] + 103.939,
                             args[,,,2,drop=F] + 116.779,
                             args[,,,3,drop=F] + 123.68),  axis = -1)
  #switch to RGB
  args <- k_reverse(args, axes = 4)
  args <- k_clip(args, 0 , 255)
  #switch to BGR
  args <- k_reverse(args, axes = 4)
  args <- k_concatenate(list(args[,,,1,drop=F] - 103.939,
                             args[,,,2,drop=F] - 116.779,
                             args[,,,3,drop=F] - 123.68),  axis = -1)
  args
}

d_out <- layer_lambda(d_out, generated_process)

#DECODER MODEL
DECODER <- keras_model(inputs = decode_input, outputs = d_out)

generated_img <- DECODER(target_features)

將合成圖像generated_img輸入ENCODE_CS,可得到其content、style的outputs。其中enc_gen是generated_img的content特徵…至此,該提取的特徵都具備好了,凍結ENCODE_CS參數權重,不必訓練更新。最後,串起建立STN model,以進行end-to-end訓練,其輸入分別為內容圖像及風格圖像,輸出為合成圖像

generated_encoded <- ENCODE_CS(generated_img)

enc_gen <- generated_encoded[[4]]

freeze_weights(ENCODE_CS)

#Style Transfer Network 
STN <- keras_model(list(content_input, style_input), generated_img)

有2種,content loss 和 style loss,即架構圖最右的Lc和Ls。content loss是計算target_features和enc_gen的MSE加總。style loss也是一樣,計算其對應layers輸出(有4層)的誤差加總

#content loss
content_loss <- function(y_true, y_pred) {
  
  c_loss <- tf$reduce_sum(tf$reduce_mean(tf$square(enc_gen - target_features), axis=c(1L, 2L)))

  c_loss
}
#style loss
style_loss <- function(y_true, y_pred) {
  
  s_loss = 0
  for(k in 1:length(generated_encoded)){
    
    generated <- generated_encoded[[k]]
    style <- style_encoded[[k]]
    
    c(gen_mean, gen_var) %<-% tf$nn$moments(generated, axes = c(1L, 2L), keep_dims = T)
    c(s_mean, s_var) %<-% tf$nn$moments(style, axes = c(1L, 2L), keep_dims = T)
    
    gen_std <- tf$sqrt(tf$add(gen_var, EPSILON))
    s_std <- tf$sqrt(tf$add(s_var, EPSILON))
    
    l2_mean <- tf$reduce_sum(tf$square(gen_mean - s_mean))
    l2_std <- tf$reduce_sum(tf$square(gen_std - s_std))
    
    s_loss <- s_loss + (l2_mean + l2_std)
  }
  s_loss
}

total_loss <- function(y_true, y_pred) {
  
  # content loss
  c_loss <-  content_loss(y_true, y_pred)
  # style loss
  s_loss <-  style_loss(y_true, y_pred)
  # total loss
  loss <- c_loss + STYLE_WEIGHT * s_loss
  
  loss
}

優化器設定及model compile

#compile
STN %>% compile(
  optimizer = optimizer_adam(lr = 0.0001, decay = (0.0001/EPOCHS)),
  loss = total_loss,
  metrics = list(custom_metric('content_loss' , content_loss), custom_metric('style_loss' , style_loss))
)

由於圖像數量龐大,不適合預先載入記億體,為加快I/O讀取速度、搭配data_generator,利用平行運算…

# doParallel
cl <- makePSOCKcluster(4)
registerDoParallel(cl)

clusterEvalQ(cl, {
  library(parallel)
  library(doParallel)
  library(abind)
  library(reticulate)
  library(keras)
  
  IMAGE_H         = 256L
  IMAGE_W         = 256L
  BATCH_SIZE      = 10L
  STYLE_WEIGHT    = 2.5
  EPSILON         = 1e-5
  EPOCHS          = 10
})

callbacks_list <- list(
  callback_tensorboard(log_dir = TSB_PATH, batch_size = BATCH_SIZE),
  callback_early_stopping(monitor = "val_loss",
                          min_delta = 0.0001, #less than min_delta will count as no improvement.
                          patience = 15,
                          verbose = 1,
                          mode = "min"),
  callback_reduce_lr_on_plateau(monitor = "val_loss",
                                factor = 0.1,
                                min_delta = 0.0001,
                                patience = 5,
                                verbose = 1,
                                mode = "min"),
  callback_model_checkpoint(filepath = file.path(SAVE_PATH,'{epoch:03d}.h5'),
                            monitor = "val_loss",
                            save_best_only = TRUE,
                            save_weights_only = TRUE,
                            mode = "min",
                            save_freq = NULL)
  )

STN %>% fit_generator(
  generator = train_iterator,
  steps_per_epoch = as.integer(length(content_filenames) / BATCH_SIZE),
  epochs = EPOCHS,
  validation_data = test_iterator,
  validation_steps = as.integer(length(content_filenames_test) / BATCH_SIZE),
  callbacks = callbacks_list
  )

stopCluster(cl)
gc()

訓練1個Epoch約2小時多一些…(使用1080 gpu)


以下是隨機挑選12張測試資料的預測結果…


隨拍照片的實測,不必重新訓練,可即時風格移轉…




這次風格移轉的實作,架構看起不難,但AdaIN的點子相當的棒,加上即時的風格轉換,運用上彈性靈活。經實際測試,發現如果訓練過多的Epochs,即使loss可以更低,但在視覺效果上看起來沒有較佳,有些細節會過度被放大(ex:容易出現網格),這部份實際原因還待查,或許在style layers新增權重說不定有幫助?!有興趣的朋友,可以試試看!對了,以上測試結果,是使用訓練完1個Epoch的權重,也就是說只要花2個小時多一些,就能達到以上效果!